import copy
import glob
import os
import time
from collections import deque
import pickle

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from a2c_ppo_acktr import algo, utils
from a2c_ppo_acktr.envs import make_vec_envs
from a2c_ppo_acktr.model import Policy
from a2c_ppo_acktr.algo.gail import ExpertDataset
from a2c_ppo_acktr.arguments import get_args
from a2c_ppo_acktr.storage import RolloutStorage
from evaluation import evaluate
from CNF import *
from datetime import datetime
import wandb


def funcpath():
    """Return the path of the pretrained CNF checkpoint to load.

    Reads ``args.funcpoint`` from the module-level ``args`` namespace
    (produced by ``get_args()``), logs it, and returns it so the caller can
    hand it straight to ``torch.load``.

    Returns:
        The checkpoint path taken from ``args.funcpoint``.
    """
    # Return the configured path itself: the caller binds the loaded
    # checkpoint to ``funcpoint`` only after this function returns, so that
    # name cannot be used here.
    path = args.funcpoint
    print(f"Loaded func from {path}")
    return path

# Parse the command-line options once; the rest of the script reads `args`.
args = get_args()

# Bind `odeint` to torchdiffeq's adjoint solver when --adjoint is set and to
# the plain solver otherwise; the training loop calls it as `odeint`.
if not args.adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint

# Seed the CPU generator first, then every CUDA device, with the same value.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

# Disable cuDNN autotuning and request deterministic kernels for
# reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
    
# --- Training-log directories ------------------------------------------------
# Both the training log dir and its "_eval" sibling are cleaned before use.
log_dir = os.path.expanduser(args.log_dir)
eval_log_dir = log_dir + "_eval"
for _dir_to_clean in (log_dir, eval_log_dir):
    utils.cleanup_log_dir(_dir_to_clean)

torch.set_num_threads(1)

# Use the GPU selected by --gpu when CUDA is available, else the CPU.
if torch.cuda.is_available():
    _device_name = 'cuda:' + str(args.gpu)
else:
    _device_name = 'cpu'
device = torch.device(_device_name)

# --- Experiment tracking -------------------------------------------------------
tags = [
    'Imitation Learning',
    'NeuralODE',
    f'Game_{args.env_name}with{args.num_demo}trajs',
]
if args.log_wandb:
    wandb.init(
        name=f"({args.env_name},{args.num_demo})sep{args.buffer_num}",
        project=f"NoiseNFIL_{args.env_name}",
        tags=tags,
    )
    wandb.config.update(args)

# --- Per-run output directory: <save_dir>/NFIL/<env_name>/<timestamp> ---------
datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
args.save_dir = os.path.join(
    args.save_dir,
    "NFIL",
    f"{args.env_name}",
    str(datetime_now),
)

# --- Models ------------------------------------------------------------------
# Integration end points for the CNF; the training loop integrates over the
# time grid [t1, t0].
t0 = 0
t1 = 10

envs = make_vec_envs(args.env_name, args.seed, args.num_processes, args.gamma,
                     args.log_dir, device, True)
obs_dim = envs.observation_space.shape[0]

# Pretrained continuous normalizing flow over (min-max scaled) states. Its
# weights come from the checkpoint returned by funcpath(); it appears to be
# used only for inference in this script (no optimizer is built for it).
func = CNF(in_out_dim=obs_dim, hidden_dim=args.hidden_dim, width=args.width).to(device)
# map_location lets a checkpoint saved on another device (e.g. a specific
# GPU) be loaded on the device this run selected, including the CPU fallback.
funcpoint = torch.load(funcpath(), map_location=device)
func.load_state_dict(funcpoint['func_state_dict'])

# Base distribution of the flow: zero mean, identity covariance, float64
# (same dtype the previous np.zeros / np.identity construction produced).
# NOTE(review): an older comment listed per-env scale factors
# ("CartPole * 2, Reacher * .3 HalfCheetah 20") that are not applied here --
# confirm whether the covariance should be scaled per environment.
p_z0 = torch.distributions.MultivariateNormal(
    loc=torch.zeros(obs_dim, dtype=torch.float64, device=device),
    covariance_matrix=torch.eye(obs_dim, dtype=torch.float64, device=device))

actor = Policy(envs.observation_space.shape, envs.action_space,
               base_kwargs={'recurrent': args.recurrent_policy})
actor.to(device)

# PPO hyper-parameters are positional; presumably (clip_param=0.2,
# ppo_epoch=4, num_mini_batch=32, value_loss_coef=0.5, entropy_coef=0.01)
# as in upstream a2c_ppo_acktr -- confirm against algo.PPO.
agent = algo.PPO(actor, 0.2, 4, 32, 0.5, 0.01, lr=args.lr, eps=1e-5, max_grad_norm=0.1)

# --- Output directory and optional resume --------------------------------------
if args.save_dir is not None:
    # args.save_dir already ends with this run's timestamp, so create it
    # directly; exist_ok avoids FileExistsError when it is already present.
    os.makedirs(args.save_dir, exist_ok=True)
    ckpt_path = os.path.join(args.save_dir, f'{args.env_name}_ckpt.pth')
    # NOTE(review): save_dir is freshly timestamped and the periodic
    # checkpoints in the training loop are named '<env>_<iter>_ckpt.pth', so
    # this resume path only triggers if such a file was placed here by hand --
    # confirm intent.
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=device)
        actor.load_state_dict(checkpoint['actor_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        print('Loaded ckpt from {}'.format(ckpt_path))

# --- Expert demonstrations -----------------------------------------------------
file_name = os.path.join(args.experts_dir, "{}.h5".format(args.env_name))
expert_dataset = ExpertDataset(file_name, num_trajectories=args.num_demo,
                               subsample_frequency=args.subsample_frequency)

# Batch size of the behaviour-cloning loader (fixed at 100, not
# args.batch_size). drop_last is derived from the SAME value, so a dataset
# smaller than one batch still yields a single partial batch instead of none.
bc_batch_size = 100
drop_last = len(expert_dataset) > bc_batch_size
train_loader = torch.utils.data.DataLoader(dataset=expert_dataset, batch_size=bc_batch_size,
                                           shuffle=True, drop_last=drop_last)

# Per-feature min / max of expert states, used by the training loop to
# min-max scale states before they enter the CNF. This is a running min/max
# over the loader's batches; samples dropped by drop_last are not included.
minvalues = maxvalues = None
for states, _actions, _seqs in train_loader:
    batch_min = torch.min(states, dim=0).values
    batch_max = torch.max(states, dim=0).values
    if minvalues is None:
        minvalues, maxvalues = batch_min, batch_max
    else:
        minvalues = torch.min(minvalues, batch_min)
        maxvalues = torch.max(maxvalues, batch_max)
if minvalues is None:
    raise RuntimeError(
        "Expert DataLoader yielded no batches; cannot compute state bounds "
        f"(dataset has {len(expert_dataset)} samples)")
minvalues = minvalues.to(device)
maxvalues = maxvalues.to(device)

# --- On-policy rollout buffer --------------------------------------------------
# The first argument (presumably the number of stored steps) is
# max_step_num * buffer_num -- confirm against RolloutStorage.
rollouts = RolloutStorage(
    args.max_step_num * args.buffer_num,
    args.num_processes,
    envs.observation_space.shape,
    envs.action_space,
    actor.recurrent_hidden_state_size,
)

# Seed slot 0 of the buffer with the initial observation, then move it.
obs = envs.reset()
rollouts.obs[0].copy_(obs)
rollouts.to(device)

# Rolling window of the 10 most recent episode returns reported in `infos`.
episode_rewards = deque(maxlen=10)

start = time.time()
# 10e6 (= 10 million) environment steps, split per process and per iteration.
num_updates = int(10e6 // args.max_step_num // args.num_processes)

# Behaviour cloning runs only on iterations j > bcstart.
if args.num_demo > 4:
    bcstart = 6
else:
    bcstart = 0
            
# Main training loop. Each iteration:
#   1. updates the learning rate,
#   2. (for j > bcstart) runs one behaviour-cloning pass over the expert data,
#   3. collects max_step_num environment steps with the current policy,
#   4. on selected iterations, replaces the stored rewards with a CNF-based
#      score of the visited states and runs a PPO update,
#   5. evaluates, logs, and checkpoints every 10 iterations.
for j in range(num_updates):
    # schedule learning rate
#     utils.update_linear_schedule(agent.optimizer, j, num_updates, args.lr)
    utils.update_step_schedule(agent.optimizer, j, num_updates, args.lr)
                    
    '''Compute policy loss'''
    print('Training Policy in BC ... ')
    # Behaviour cloning: one pass over the expert loader, minimising the
    # negative mean log-probability of expert actions under the policy.
    # Skipped for the first bcstart + 1 iterations (j = 0 .. bcstart).
    if j > bcstart:
#         BCloss = 0
        for i, (states, actions, seqs) in enumerate(train_loader):
            # Recurrent hidden state is zeroed and masks are all ones for
            # every expert batch.
            _, bcaction_log_probs, _, _ = actor.evaluate_actions(
                        states.to(device), torch.zeros(states.size()[0], actor.recurrent_hidden_state_size).to(device), 
                        torch.ones(states.size()[0], 1).to(device), actions.to(device))
            BCloss = - bcaction_log_probs.mean(0)
            
            agent.optimizer.zero_grad()
            BCloss.backward()
            nn.utils.clip_grad_norm_(agent.actor_critic.parameters(), agent.max_grad_norm)
            agent.optimizer.step()

        # Only the loss of the last minibatch of the pass is logged.
        if args.log_wandb:
            wandb.log({"BC loss": BCloss.item(), "epoch": j})
    
    '''Generate trajectories ... '''
    # Roll out the current policy for max_step_num steps into `rollouts`.
    # NOTE(review): inputs are read from index `step` (0 .. max_step_num-1)
    # every iteration, while rollouts.insert() presumably advances its own
    # write pointer over a buffer of max_step_num * buffer_num steps; when
    # buffer_num > 1 and after_update() is skipped these may not line up --
    # confirm against RolloutStorage.
    for step in range(args.max_step_num):
        # Sample actions
        with torch.no_grad():
            value, action, action_log_prob, recurrent_hidden_states = actor.act(
                rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                rollouts.masks[step])

        # Observe reward and next obs. The environment reward is stored, but
        # the first max_step_num reward slots are overwritten below whenever
        # the CNF relabelling branch runs.
        obs, reward, done, infos = envs.step(action)

        for info in infos:
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])

        # If done then clean the history of observations.
        # masks is 0 where `done` is set; bad_masks is 0 where the info dict
        # carries 'bad_transition' (presumably a time-limit truncation).
        masks = torch.FloatTensor(
            [[0.0] if done_ else [1.0] for done_ in done])
        bad_masks = torch.FloatTensor(
            [[0.0] if 'bad_transition' in info.keys() else [1.0]
             for info in infos])
        rollouts.insert(obs, recurrent_hidden_states, action,
                        action_log_prob, value, reward, masks, bad_masks)

    # Value estimate of the last stored observation, passed to compute_returns.
    with torch.no_grad():
        next_value = actor.get_value(
            rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
            rollouts.masks[-1]).detach()

    # Reward relabelling + PPO update: runs on every iteration j <= bcstart,
    # afterwards only when j % buffer_num == 1.
    # NOTE(review): with buffer_num == 1, j % 1 is always 0, so after bcstart
    # no PPO update (and no after_update) ever runs -- confirm intent.
    if j <= bcstart or j%args.buffer_num == 1:
        with torch.no_grad():
            # Min-max scale every stored observation with the expert bounds,
            # integrate (scaled states, logp_diff) through `func` over the
            # time grid [t1, t0], and score the end points:
            #   logp_x = log p_z0(z_t0) - logp_diff_t0.
            # NOTE(review): a feature with maxvalues == minvalues divides by
            # zero in the scaling below.
            # NOTE(review): logp_diff_t1 has rollouts.obs.size()[0] rows, but
            # the flattened observations have one row per element of all
            # dims except the last; these agree only when that product equals
            # size()[0] (e.g. num_processes == 1) -- confirm.
            logp_diff_t1 = torch.zeros(rollouts.obs.size()[0], 1).type(torch.float32).to(device)
            z_t, logp_diff_t = odeint(    # the ODEsolver
                func,
                ((rollouts.obs.view(-1, rollouts.obs.size()[-1])-minvalues)/(maxvalues-minvalues), logp_diff_t1), # noise
                torch.tensor([t1, t0]).type(torch.float32).to(device),
                atol=1e-5,
                rtol=1e-5,
                method='dopri5')
            z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]
            logp_x = p_z0.log_prob(z_t0).view(-1).to(device) - logp_diff_t0.view(-1)

        # reward = -clamp(logp_x, 1e-8, 2), so rewards lie in [-2, -1e-8].
        # Only the first max_step_num reward slots are overwritten.
        # NOTE(review): logp_x is indexed over the flattened rows, so
        # logp_x[step] matches rollout step `step` only when
        # num_processes == 1.
        for step in range(args.max_step_num):
            rollouts.rewards[step] = -logp_x[step].clamp(1e-8, 2)

        # Positional args presumably (next_value, use_gae=True, gamma=0.99,
        # gae_lambda=0.95, use_proper_time_limits=True) -- confirm against
        # RolloutStorage.compute_returns.
        rollouts.compute_returns(next_value, True, 0.99, 0.95, True)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

    '''Evaluate and save model'''
    # NOTE(review): this passes args.eval_log_dir, not the local eval_log_dir
    # derived from args.log_dir at the top of the script -- confirm which one
    # is intended.
    evaluate(actor, args.env_name, args.seed, args.num_processes, args.eval_log_dir, device, args, j)
    # value_loss / action_loss come from the most recent PPO update; on
    # iterations where the update branch is skipped they are stale values.
    if args.log_wandb:
        wandb.log({"Action loss": action_loss, "Value loss": value_loss, 
                   "epoch": j})
#     print('Policy Iter: {}, accumulated loss: {:.4f}'.format(j, BCloss.item()))
    # Checkpoint the actor and its optimizer every 10 iterations.
    # NOTE(review): files are named '<env>_<j>_ckpt.pth', whereas the resume
    # logic earlier in the script looks for '<env>_ckpt.pth'.
    if j % 10 == 0:
        ckpt_path = os.path.join(args.save_dir, '{}_{}_ckpt.pth'.format(args.env_name, j))
        torch.save({
            'actor_state_dict': actor.state_dict(),
            'actor_optimizer_state_dict': agent.optimizer.state_dict(),
        }, ckpt_path)
        print('Stored ckpt at {}'.format(ckpt_path))
